{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hRWOI1nxutyx"
      },
      "source": [
        "# Overview\n",
        "This codelab will demonstrate how to build a LSTM model for MNIST recognition using keras \u0026 how to convert the model to TensorFlow Lite.\n",
        "\n",
        "---\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "tXzpJuM7zujk"
      },
      "outputs": [],
      "source": [
        "!pip install tf-nightly"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "LOE_xIJuvMOU"
      },
      "source": [
        "### Prerequisites\n",
        "We're going to override the environment variable `TF_ENABLE_CONTROL_FLOW_V2` since for TensorFlow Lite control flows."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Vpx_hISazpZJ"
      },
      "outputs": [],
      "source": [
        "# This is important!\n",
        "import os\n",
        "os.environ['TF_ENABLE_CONTROL_FLOW_V2'] = '1'\n",
        "\n",
        "import tensorflow as tf\n",
        "import numpy as np"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "R3Ku1Lx9vvfX"
      },
      "source": [
        "## Step 1 Build the MNIST LSTM model.\n",
        "\n",
        "Note we will be using **`tf.lite.experimental.nn.TFLiteLSTMCell`** \u0026 **`tf.lite.experimental.nn.dynamic_rnn`** in the tutorial.\n",
        "\n",
        "Also note here, we're not trying to build the model to be a real world application, but only demonstrates how to use TensorFlow lite. You can a build a much better model using CNN models.\n",
        "\n",
        "For more canonical lstm codelab, please see [here](https://github.com/keras-team/keras/blob/master/examples/imdb_lstm.py)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "wiYZoDlC5SEJ"
      },
      "outputs": [],
      "source": [
        "# Step 1: Build the MNIST LSTM model.\n",
        "def buildLstmLayer(inputs, num_layers, num_units):\n",
        "  \"\"\"Build the lstm layer.\n",
        "\n",
        "  Args:\n",
        "    inputs: The input data.\n",
        "    num_layers: How many LSTM layers do we want.\n",
        "    num_units: The unmber of hidden units in the LSTM cell.\n",
        "  \"\"\"\n",
        "  lstm_cells = []\n",
        "  for i in range(num_layers):\n",
        "    lstm_cells.append(\n",
        "        tf.lite.experimental.nn.TFLiteLSTMCell(\n",
        "            num_units, forget_bias=0, name='rnn{}'.format(i)))\n",
        "  lstm_layers = tf.keras.layers.StackedRNNCells(lstm_cells)\n",
        "  # Assume the input is sized as [batch, time, input_size], then we're going\n",
        "  # to transpose to be time-majored.\n",
        "  transposed_inputs = tf.transpose(\n",
        "      inputs, perm=[1, 0, 2])\n",
        "  outputs, _ = tf.lite.experimental.nn.dynamic_rnn(\n",
        "      lstm_layers,\n",
        "      transposed_inputs,\n",
        "      dtype='float32',\n",
        "      time_major=True)\n",
        "  unstacked_outputs = tf.unstack(outputs, axis=0)\n",
        "  return unstacked_outputs[-1]\n",
        "\n",
        "tf.reset_default_graph()\n",
        "model = tf.keras.models.Sequential([\n",
        "  tf.keras.layers.Input(shape=(28, 28), name='input'),\n",
        "  tf.keras.layers.Lambda(buildLstmLayer, arguments={'num_layers' : 2, 'num_units' : 64}),\n",
        "  tf.keras.layers.Flatten(),\n",
        "  tf.keras.layers.Dense(10, activation=tf.nn.softmax, name='output')\n",
        "])\n",
        "model.compile(optimizer='adam',\n",
        "              loss='sparse_categorical_crossentropy',\n",
        "              metrics=['accuracy'])\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Ff6X9gg_wk7K"
      },
      "source": [
        "## Step 2: Train \u0026 Evaluate the model.\n",
        "We will train the model using MNIST data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "23W41fiRPOmh"
      },
      "outputs": [],
      "source": [
        "# Step 2: Train \u0026 Evaluate the model.\n",
        "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
        "x_train, x_test = x_train / 255.0, x_test / 255.0\n",
        "\n",
        "# Cast x_train \u0026 x_test to float32.\n",
        "x_train = x_train.astype(np.float32)\n",
        "x_test = x_test.astype(np.float32)\n",
        "\n",
        "model.fit(x_train, y_train, epochs=5)\n",
        "model.evaluate(x_test, y_test)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "NtPJGiIQw0nM"
      },
      "source": [
        "## Step 3: Convert the Keras model to TensorFlow Lite model.\n",
        "\n",
        "Note here: we just convert to TensorFlow Lite model as usual."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Tbuu_8PFz-x_"
      },
      "outputs": [],
      "source": [
        "# Step 3: Convert the Keras model to TensorFlow Lite model.\n",
        "sess = tf.keras.backend.get_session()\n",
        "input_tensor = sess.graph.get_tensor_by_name('input:0')\n",
        "output_tensor = sess.graph.get_tensor_by_name('output/Softmax:0')\n",
        "converter = tf.lite.TFLiteConverter.from_session(\n",
        "    sess, [input_tensor], [output_tensor])\n",
        "tflite = converter.convert()\n",
        "print('Model converted successfully!')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5rHrZkIuxxar"
      },
      "source": [
        "## Step 4: Check the converted TensorFlow Lite model.\n",
        "\n",
        "We're just going to load the TensorFlow Lite model and use the TensorFlow Lite python interpreter to verify the results."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "8lao097MnFf2"
      },
      "outputs": [],
      "source": [
        "# Step 4: Check the converted TensorFlow Lite model.\n",
        "interpreter = tf.lite.Interpreter(model_content=tflite)\n",
        "\n",
        "try:\n",
        "  interpreter.allocate_tensors()\n",
        "except ValueError:\n",
        "  assert False\n",
        "\n",
        "MINI_BATCH_SIZE = 1\n",
        "correct_case = 0\n",
        "for i in range(len(x_test)):\n",
        "  input_index = (interpreter.get_input_details()[0]['index'])\n",
        "  interpreter.set_tensor(input_index, x_test[i * MINI_BATCH_SIZE: (i + 1) * MINI_BATCH_SIZE])\n",
        "  interpreter.invoke()\n",
        "  output_index = (interpreter.get_output_details()[0]['index'])\n",
        "  result = interpreter.get_tensor(output_index)\n",
        "  # Reset all variables so it will not pollute other inferences.\n",
        "  interpreter.reset_all_variables()\n",
        "  # Evaluate.\n",
        "  prediction = np.argmax(result)\n",
        "  if prediction == y_test[i]:\n",
        "    correct_case += 1\n",
        "\n",
        "print('TensorFlow Lite Evaluation result is {}'.format(correct_case * 1.0 / len(x_test)))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "TensorFlowLite_LSTM_Keras_Tutorial.ipynb",
      "provenance": [],
      "version": "0.3.2"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
